Package org.apache.lucene.classification

Source Code of org.apache.lucene.classification.KNearestNeighborClassifier

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.queries.mlt.MoreLikeThis;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.BytesRef;

import java.io.IOException;
import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;

/**
* A k-Nearest Neighbor classifier (see <code>http://en.wikipedia.org/wiki/K-nearest_neighbors</code>) based
* on {@link MoreLikeThis}
*
* @lucene.experimental
*/
public class KNearestNeighborClassifier implements Classifier<BytesRef> {

  private MoreLikeThis mlt;
  private String textFieldName;
  private String classFieldName;
  private IndexSearcher indexSearcher;
  private int k;

  /**
   * Create a {@link Classifier} using kNN algorithm
   *
   * @param k the number of neighbors to analyze as an <code>int</code>
   */
  public KNearestNeighborClassifier(int k) {
    this.k = k;
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
    Query q = mlt.like(new StringReader(text), textFieldName);
    TopDocs topDocs = indexSearcher.search(q, k);
    return selectClassFromNeighbors(topDocs);
  }

  private ClassificationResult<BytesRef> selectClassFromNeighbors(TopDocs topDocs) throws IOException {
    // TODO : improve the nearest neighbor selection
    Map<BytesRef, Integer> classCounts = new HashMap<BytesRef, Integer>();
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
      BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
      if (cl != null) {
        Integer count = classCounts.get(cl);
        if (count != null) {
          classCounts.put(cl, count + 1);
        } else {
          classCounts.put(cl, 1);
        }
      }
    }
    double max = 0;
    BytesRef assignedClass = new BytesRef();
    for (BytesRef cl : classCounts.keySet()) {
      Integer count = classCounts.get(cl);
      if (count > max) {
        max = count;
        assignedClass = cl.clone();
      }
    }
    double score = max / (double) k;
    return new ClassificationResult<BytesRef>(assignedClass, score);
  }

  /**
   * {@inheritDoc}
   */
  @Override
  public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
    this.textFieldName = textFieldName;
    this.classFieldName = classFieldName;
    mlt = new MoreLikeThis(atomicReader);
    mlt.setAnalyzer(analyzer);
    mlt.setFieldNames(new String[]{textFieldName});
    indexSearcher = new IndexSearcher(atomicReader);
  }
}
TOP

Related Classes of org.apache.lucene.classification.KNearestNeighborClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.